// This file is made available under Elastic License 2.0.
// This file is based on code available under the Apache license here:
//   https://github.com/apache/orc/tree/main/c++/src/ColumnWriter.hh

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef ORC_COLUMN_WRITER_HH
#define ORC_COLUMN_WRITER_HH

#include <utility>

#include "BloomFilter.hh"
#include "ByteRLE.hh"
#include "Compression.hh"
#include "Statistics.hh"
#include "orc/Exceptions.hh"
#include "orc/Vector.hh"
#include "wrap/orc-proto-wrapper.hh"

namespace orc {

/**
 * Abstract interface for creating buffered output streams by stream kind.
 * Instances are obtained through createStreamsFactory().
 */
class StreamsFactory {
public:
    // Declared here only; the definition lives outside this header.
    virtual ~StreamsFactory();

    /**
     * Get the stream for the given column/kind in this stripe.
     * @param kind the kind of the stream
     * @return the buffered output stream; ownership passes to the caller
     */
    virtual std::unique_ptr<BufferedOutputStream> createStream(proto::Stream_Kind kind) const = 0;
};

/**
 * Create a StreamsFactory.
 * @param options writer options handed to the factory
 * @param outStream output stream handed to the factory; it is passed as a raw
 *                  pointer, so presumably it is not owned and must outlive
 *                  the returned factory -- TODO confirm against the definition
 * @return the newly created factory; ownership passes to the caller
 */
std::unique_ptr<StreamsFactory> createStreamsFactory(const WriterOptions& options, OutputStream* outStream);

/**
 * Record stream positions for row index.
 *
 * Every recorded position is appended to the RowIndexEntry supplied at
 * construction. The entry is held by reference, so it must outlive this
 * recorder.
 */
class RowIndexPositionRecorder : public PositionRecorder {
public:
    ~RowIndexPositionRecorder() override;

    /**
     * Bind the recorder to a row index entry.
     *
     * Marked explicit so that a RowIndexEntry is never silently converted
     * into a recorder bound to it.
     * @param entry the row index entry that receives recorded positions
     */
    explicit RowIndexPositionRecorder(proto::RowIndexEntry& entry) : rowIndexEntry(entry) {}

    /**
     * Append a position to the bound row index entry.
     * @param pos the position to record
     */
    void add(uint64_t pos) override { rowIndexEntry.add_positions(pos); }

private:
    // Not owned: refers to the entry passed to the constructor.
    proto::RowIndexEntry& rowIndexEntry;
};

/**
 * The interface for writing ORC data types.
 *
 * Only getColumnEncoding() is pure virtual; every other virtual member is
 * declared here and defined outside this header.
 *
 * NOTE: the order of the data members below is significant (it fixes the
 * order of initialisation and destruction) and must not be rearranged.
 */
class ColumnWriter {
protected:
    std::unique_ptr<ByteRleEncoder> notNullEncoder;
    uint64_t columnId;
    // statistics kept at three granularities: row index, stripe and file
    std::unique_ptr<MutableColumnStatistics> colIndexStatistics;
    std::unique_ptr<MutableColumnStatistics> colStripeStatistics;
    std::unique_ptr<MutableColumnStatistics> colFileStatistics;

    bool enableIndex;
    // row index for this column, contains all RowIndexEntries in 1 stripe
    std::unique_ptr<proto::RowIndex> rowIndex;
    std::unique_ptr<proto::RowIndexEntry> rowIndexEntry;
    std::unique_ptr<RowIndexPositionRecorder> rowIndexPosition;

    // bloom filters are recorded per row group
    bool enableBloomFilter;
    std::unique_ptr<BloomFilterImpl> bloomFilter;
    std::unique_ptr<proto::BloomFilterIndex> bloomFilterIndex;

public:
    /**
     * Construct a column writer.
     * @param type the type of the column to write
     * @param factory factory handed to the writer (presumably used to create
     *                its output streams -- confirm against the definition)
     * @param options writer options
     */
    ColumnWriter(const Type& type, const StreamsFactory& factory, const WriterOptions& options);

    virtual ~ColumnWriter();

    /**
     * Write the next group of values from this rowBatch.
     * @param rowBatch the row batch data to write
     * @param offset the starting point of row batch to write
     * @param numValues the number of values to write
     * @param incomingMask if null, all values are not null. Otherwise, it is
     *                     a mask (with at least numValues bytes) for which
     *                     values to write.
     */
    virtual void add(ColumnVectorBatch& rowBatch, uint64_t offset, uint64_t numValues, const char* incomingMask);

    /**
     * Flush column writer output streams.
     * @param streams vector to store streams generated by flush()
     */
    virtual void flush(std::vector<proto::Stream>& streams);

    /**
     * Get estimated size of buffer used.
     * @return estimated size of buffer used
     */
    virtual uint64_t getEstimatedSize() const;

    /**
     * Get the encoding used by the writer for this column.
     * @param encodings vector to store the returned ColumnEncoding info
     */
    virtual void getColumnEncoding(std::vector<proto::ColumnEncoding>& encodings) const = 0;

    /**
     * Get the stripe statistics for this column.
     * @param stats vector to store the returned stripe statistics
     */
    virtual void getStripeStatistics(std::vector<proto::ColumnStatistics>& stats) const;

    /**
     * Get the file statistics for this column.
     * @param stats vector to store the returned file statistics
     */
    virtual void getFileStatistics(std::vector<proto::ColumnStatistics>& stats) const;

    /**
     * Merge index stats into stripe stats and reset index stats.
     */
    virtual void mergeRowGroupStatsIntoStripeStats();

    /**
     * Merge stripe stats into file stats and reset stripe stats.
     */
    virtual void mergeStripeStatsIntoFileStats();

    /**
     * Create a row index entry with the previous location and the current
     * index statistics. Also merges the index statistics into the stripe
     * statistics before they are cleared. Finally, it records the start of the
     * next index and ensures all of the children columns also create an entry.
     */
    virtual void createRowIndexEntry();

    /**
     * Create a new BloomFilter entry and add the previous one to BloomFilterIndex
     */
    virtual void addBloomFilterEntry();

    /**
     * Write row index streams for this column.
     * @param streams output list of ROW_INDEX streams
     */
    virtual void writeIndex(std::vector<proto::Stream>& streams) const;

    /**
     * Record positions for index.
     *
     * This function is called by createRowIndexEntry() and ColumnWriter's
     * constructor. So base classes do not need to call inherited classes'
     * recordPosition() function.
     *
     * NOTE(review): the "base"/"inherited" wording above is ambiguous, and a
     * virtual call made from a constructor does not dispatch to overrides in
     * derived classes. Confirm against the implementation which classes are
     * expected to invoke recordPosition() themselves.
     */
    virtual void recordPosition() const;

    /**
     * Reset positions for index.
     */
    virtual void reset();

    /**
     * Write dictionary to streams for string columns
     */
    virtual void writeDictionary();

protected:
    /**
     * Utility function to translate ColumnStatistics into protobuf form and
     * add it to output list.
     * @param statsList output list for protobuf stats
     * @param stats ColumnStatistics to be transformed and added; it is
     *              dereferenced unconditionally, so it must not be null
     */
    void getProtoBufStatistics(std::vector<proto::ColumnStatistics>& statsList,
                               const MutableColumnStatistics* stats) const {
        proto::ColumnStatistics pbStats;
        stats->toProtoBuf(pbStats);
        // pbStats is a local that is not used again, so hand it over by move
        // instead of copying the whole message into the vector.
        statsList.push_back(std::move(pbStats));
    }

protected:
    MemoryPool& memPool;
    std::unique_ptr<BufferedOutputStream> indexStream;
    std::unique_ptr<BufferedOutputStream> bloomFilterStream;
};

/**
 * Create a writer for the given type.
 * @param type the type of the column to build a writer for
 * @param factory streams factory handed to the writer (presumably used to
 *                create its output streams -- confirm against the definition)
 * @param options writer options handed to the writer
 * @return the newly created column writer; ownership passes to the caller
 */
std::unique_ptr<ColumnWriter> buildWriter(const Type& type, const StreamsFactory& factory,
                                          const WriterOptions& options);
} // namespace orc

#endif
